# -*- coding: utf-8 -*-
# @Time    : 2023/3/18 11:23
# @Author  : xiehou
# @File    : _utils.py
# @Software: PyCharm
from torch.utils.data import Dataset
import torch
from torch import nn


class MyDataset(Dataset):
    """Map-style ``Dataset`` wrapping an in-memory indexable container.

    Both ``len(ds)`` and ``ds[i]`` are delegated directly to the wrapped
    object, so any sequence-like container (list, tuple, tensor, ...) works.
    """

    def __init__(self, data):
        # Store a reference to the container; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item stored at ``index`` in the wrapped container."""
        item = self.data[index]
        return item


def bias_loss(weights=None, device='cuda'):
    """Build a cross-entropy loss callable with optional per-class weights.

    Args:
        weights: Optional per-class rescaling weights (list, numpy array or
            tensor), forwarded to ``nn.CrossEntropyLoss(weight=...)``.
            ``None`` weights every class equally.
        device: Device the weight tensor is placed on. It should match the
            device of the predictions later passed to the returned callable.
            It is unused when ``weights`` is ``None``.

    Returns:
        A function ``loss(pred, target)``. It treats the last axis of ``pred``
        as the class axis and folds all leading axes into the batch axis. It
        flattens ``target`` to 1-D. It then returns the (scalar, mean-reduced)
        cross-entropy.
    """
    if weights is not None:
        # as_tensor accepts lists, numpy arrays and existing tensors of any
        # dtype/device, unlike the legacy torch.FloatTensor constructor.
        weights = torch.as_tensor(weights, dtype=torch.float32, device=device)
    cross_en = nn.CrossEntropyLoss(weight=weights)

    def _loss(pred, target):
        """Flatten ``pred`` to (N, C) and ``target`` to (N,), then apply CE."""
        # reshape (not view) so non-contiguous inputs -- e.g. logits after
        # transpose/permute -- are handled instead of raising RuntimeError.
        return cross_en(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

    return _loss
